/*
 *	Copyright 2023 Jan Pfeifer
 *
 *	Licensed under the Apache License, Version 2.0 (the "License");
 *	you may not use this file except in compliance with the License.
 *	You may obtain a copy of the License at
 *
 *	http://www.apache.org/licenses/LICENSE-2.0
 *
 *	Unless required by applicable law or agreed to in writing, software
 *	distributed under the License is distributed on an "AS IS" BASIS,
 *	WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *	See the License for the specific language governing permissions and
 *	limitations under the License.
 */

// Package train holds tools to help run a training loop.
//
// It provides tooling at various levels, from running a single training
// step to running a full loop. It also serves as an example that users who
// need more flexibility can start from to create their own training loops.
package train

import (
	"fmt"
	"io"
	"iter"
	"slices"

	"github.com/gomlx/gomlx/backends"
	"github.com/gomlx/gomlx/internal/exceptions"
	"github.com/gomlx/gomlx/pkg/core/distributed"
	"github.com/gomlx/gomlx/pkg/core/graph"
	"github.com/gomlx/gomlx/pkg/core/tensors"
	"github.com/gomlx/gomlx/pkg/ml/context"
	"github.com/gomlx/gomlx/pkg/ml/train/losses"
	"github.com/gomlx/gomlx/pkg/ml/train/metrics"
	"github.com/gomlx/gomlx/pkg/ml/train/optimizers"
	"github.com/pkg/errors"
)

// Scopes and graph parameter keys used by the Trainer and its helpers.
const (
	// TrainerAbsoluteScope is the absolute scope used for Context parameters related to the trainer.
	TrainerAbsoluteScope = context.ScopeSeparator + "trainer"

	// TrainerLossGraphParamKey is the graph parameter key that maps to the node holding the
	// accumulated losses of the model.
	TrainerLossGraphParamKey = "trainer_loss"

	// TrainerLossNoRegularizationGraphParamKey is the graph parameter key that maps to the node holding the
	// loss returned by the lossFn given to the trainer, without the regularization terms.
	//
	// It maps to nil if the trainer was not given a lossFn.
	TrainerLossNoRegularizationGraphParamKey = "trainer_loss_no_regularization"

	// TrainerPerStepUpdateGraphFnParamKey is the graph parameter key used by AddPerStepUpdateGraphFn.
	TrainerPerStepUpdateGraphFnParamKey = "trainer_per_step_update_graph_fn"
)

// Trainer is a helper object to orchestrate a training step and evaluation.
//
// Given the inputs and labels, it deals with executing a training step (TrainStep) and
// evaluation (EvalStep and Eval), calling the loss and optimizer and running metrics.
//
// It caches one executor (holding a JIT-compiled computation graph) per `spec` value and per type of
// step (train, eval, batch-norm averages, gradient accumulation), up to a configurable limit
// (see WithMaxExecutors).
//
// See Loop for a flexible and extensible (different UIs) way to run this in a training loop.
type Trainer struct {
	// Model definition: the backend used to create and compile the graphs, the Context holding the
	// variables (it can be replaced with SetContext), the model function, the loss function (it may be
	// nil if losses are provided with AddLoss) and the optimizer.
	// NOTE(review): deviceNum is set to 0 by NewTrainer, but it is not otherwise read in this part of the file.
	backend   backends.Backend
	context   *context.Context
	deviceNum backends.DeviceNum
	modelFn   ModelFn
	lossFn    LossFn
	optimizer optimizers.Interface

	// Distribute execution:
	// deviceAssignment is set with WithDeviceAssignment. If nil, the backend's default assignment is used.
	deviceAssignment []backends.DeviceNum // Optional.

	// maxExecutors to cache. Creating new executors fails after that. One executor is created per `spec`
	// value and per type of step (train, eval, etc.), and the limit applies to the total over all types.
	//
	// inputsAndLabelsLenPerSpec records, for each `spec`, the number of inputs and labels seen when an
	// executor was created for it, so later steps can check the dataset yields consistent lengths.
	maxExecutors              int
	inputsAndLabelsLenPerSpec map[any][2]int

	// Training data: one train-step executor per `spec`, and the train metrics (the default loss
	// metrics created by NewTrainer come first).
	trainStepExecMap map[any]*context.Exec
	trainMetrics     []metrics.Interface

	// Accumulate gradients mode: per `spec` there is one executor that only accumulates the gradients,
	// and another one that accumulates and applies them.
	// NOTE(review): the step counters are managed by the gradient accumulation code outside this part of the file.
	accumulateGradients                bool
	accumulateGradientsSteps           int
	accumulateGradientsCurrentStep     int
	accumulateGradientsExecMap         map[any]*context.Exec
	accumulateGradientsAndApplyExecMap map[any]*context.Exec

	// Eval data: one eval-step executor per `spec`, and the eval metrics (the default loss metrics
	// created by NewTrainer come first).
	evalStepExecMap map[any]*context.Exec
	evalMetrics     []metrics.Interface

	// BatchNormAverages data: one executor per `spec`.
	batchNormStepExecMap map[any]*context.Exec

	// onExecCreationHandlers are hooks called during executor creation, with the new executor and
	// the GraphType it was created for.
	onExecCreationHandlers []OnExecFn
}

// ModelFn is a computation graph building function that takes as input a `spec` and a slice of `inputs`
// (even if just one) generated by a Dataset, and as output a slice (even if only one) of the `predictions`
// (or sometimes the logits).
//
// The `predictions` output by ModelFn are fed to the LossFn (if one is given) and to the metrics, both
// during training and evaluation. During training, if ModelFn returns no predictions, the Trainer feeds the
// metrics a scalar zero instead.
//
// Notice `spec` is opaque to train package, it's passed from the train.Dataset to the ModelFn, and its
// meaning is determined by the train.Dataset used. For static case (where data is always the same) it can simply be
// nil. Each value of `spec` is mapped to different computation graphs by the train.Trainer.
type ModelFn func(ctx *context.Context, spec any, inputs []*graph.Node) (predictions []*graph.Node)

// LossFn takes the output of ModelFn (called predictions, but it could be the logits),
// and the labels (coming out of Dataset.Yield()), and outputs the scalar loss, that can
// be used for training. The Trainer reduces a non-scalar loss to a scalar with ReduceAllMean.
//
// For some types of self-supervised models for which there are no labels, the labels can be empty.
//
// Most of the predefined losses in package `gomlx/ml/train/losses` assume labels and predictions are
// both of length one. For multi-head models, it's very easy to write a small custom LossFn that splits
// the slice and send each label/prediction pair to a predefined loss.
//
// It is an alias to the type defined in the losses package.
type LossFn = losses.LossFn

// DefaultMaxExecutors is the initial maximum number of executors of Trainer objects, read by NewTrainer
// (changing it afterwards doesn't affect existing trainers; see Trainer.WithMaxExecutors).
//
// Each different `spec` value from a Dataset triggers the creation of a new executor for each type of step
// used (train, eval, etc.), and the limit applies to the total number of executors.
//
// If using AccumulateGradients, there will be 2 graphs per shape of input data: one to accumulate the gradients
// and another to accumulate and apply the gradients.
var DefaultMaxExecutors = 50

// GraphType identifies the kind of computation graph (and executor) built by the Trainer, when there
// needs to be a distinction -- e.g., it is passed to the handlers called when a new executor is created.
type GraphType int

const (
	// TrainType identifies the graphs built for training steps (e.g. TrainStep).
	TrainType GraphType = iota

	// EvalType identifies the graphs built for evaluation steps (e.g. EvalStep and Eval).
	EvalType

	// BatchNormAveragesType identifies the graphs built to update the batch normalization averages.
	BatchNormAveragesType
)

// NewTrainer constructs a trainer that can be used for training steps and evaluation. The given Context holds
// the model's variables, hyperparameters, and other information; it can later be replaced with Trainer.SetContext.
//
// Its arguments are:
//
//   - backend needed to create and compile computation graphs.
//
//   - ctx (will) hold the variables, hyperparameters, and related information for the model.
//
//   - modelFn builds the graph that transforms inputs into predictions (or logits).
//
//   - lossFn takes the predictions (the output of modelFn) and the labels and outputs the loss. If the
//     returned loss is not a scalar, it will be ReduceAllMean to a scalar.
//     There are several standard losses available in gomlx/ml/train/losses package.
//     They can simply be used as is, or called by arbitrary custom losses.
//     It can also be set to nil, if one is providing loss terms with `AddLoss` -- e.g.: for unsupervised training.
//
//   - optimizer (e.g: optimizers.StochasticGradientDescent) is the methodology to improve the model variables (aka.
//     parameters or weights) to minimize the loss (the output of lossFn), typically using gradient descent.
//
//   - trainMetrics are output by trainer.TrainStep after each step. Here it's recommended to use moving average
//     types of metrics, since the model is changing so a mean wouldn't make sense. The batch loss (including
//     regularization) and its moving average are always included first, followed by the moving average of the
//     loss without regularization if lossFn is not nil. It's ok to be empty (nil).
//
//   - evalMetrics are output by trainer.EvalStep and trainer.Eval. Here it's recommended to use mean metrics, since
//     the model is presumably frozen, and it sees each example exactly once. The mean loss (including
//     regularization) over the dataset is always provided as the first metric, followed by the mean loss without
//     regularization if lossFn is not nil. It's ok to be empty (nil).
//
// It panics if it fails to delete the optimizer's learning-rate variable from ctx: it is deleted so that it is
// forcefully reinitialized every time the model is retrained.
func NewTrainer(backend backends.Backend, ctx *context.Context,
	modelFn ModelFn, lossFn LossFn, optimizer optimizers.Interface,
	trainMetrics, evalMetrics []metrics.Interface) *Trainer {

	r := &Trainer{
		backend:   backend,
		context:   ctx,
		deviceNum: 0,
		modelFn:   modelFn,
		lossFn:    lossFn,
		optimizer: optimizer,

		maxExecutors:              DefaultMaxExecutors,
		inputsAndLabelsLenPerSpec: make(map[any][2]int),
		trainStepExecMap:          make(map[any]*context.Exec),
		evalStepExecMap:           make(map[any]*context.Exec),
		batchNormStepExecMap:      make(map[any]*context.Exec),

		accumulateGradientsExecMap:         make(map[any]*context.Exec),
		accumulateGradientsAndApplyExecMap: make(map[any]*context.Exec),
	}

	// Delete variables that should forcefully be reinitialized every time the model is retrained.
	optScope := ctx.In(optimizers.Scope).Scope()
	err := ctx.DeleteVariable(optScope, optimizers.ParamLearningRate)
	if err != nil {
		panic(err)
	}

	// Both loss functions below read the losses registered in the graph being built, through the Context
	// the metric passes them. They must not use the `ctx` captured here: SetContext may later replace the
	// Context, and the losses would then be registered in a different one.
	batchLossFn := func(metricCtx *context.Context, _, predictions []*graph.Node) *graph.Node {
		g := predictions[0].Graph()
		loss := GetLosses(metricCtx, g)
		if loss == nil {
			return graph.ScalarZero(g, predictions[0].DType())
		}
		return loss
	}
	lossNoRegularizationFn := func(metricCtx *context.Context, _, predictions []*graph.Node) *graph.Node {
		g := predictions[0].Graph()
		loss := GetLossNoRegularization(metricCtx, g)
		if loss == nil {
			return graph.ScalarZero(g, predictions[0].DType())
		}
		return loss
	}

	// Train metrics: automatically include batch loss and moving average loss metrics first.
	lossAndMetrics := make([]metrics.Interface, 0, len(trainMetrics)+3)
	lossAndMetrics = append(
		lossAndMetrics,
		metrics.NewBaseMetric("Batch Loss+Regularization", "loss+", metrics.LossMetricType, batchLossFn, nil),
	)
	lossAndMetrics = append(
		lossAndMetrics,
		metrics.NewExponentialMovingAverageMetric(
			"Moving Average Loss+Regularization",
			"~loss+",
			metrics.LossMetricType,
			batchLossFn,
			nil,
			0.01,
		),
	)
	if r.lossFn != nil {
		lossAndMetrics = append(lossAndMetrics,
			metrics.NewExponentialMovingAverageMetric("Moving Average Loss", "~loss", metrics.LossMetricType,
				lossNoRegularizationFn, nil, 0.01))
	}
	lossAndMetrics = append(lossAndMetrics, trainMetrics...)
	r.trainMetrics = lossAndMetrics

	// Eval metrics: automatically include the mean loss metrics first.
	lossAndMetrics = make([]metrics.Interface, 0, len(evalMetrics)+2)
	lossAndMetrics = append(
		lossAndMetrics,
		metrics.NewMeanMetric("Mean Loss+Regularization", "#loss+", metrics.LossMetricType, batchLossFn, nil),
	)
	if r.lossFn != nil {
		lossAndMetrics = append(lossAndMetrics, metrics.NewMeanMetric("Mean Loss", "#loss", metrics.LossMetricType,
			lossNoRegularizationFn, nil))
	}
	lossAndMetrics = append(lossAndMetrics, evalMetrics...)
	r.evalMetrics = lossAndMetrics
	return r
}

// WithDeviceAssignment sets the backend device assignment to use.
//
// This works for both distributed and normal (single device) training. The latter uses only one device where to
// execute the training/evaluation steps.
//
// Calling it with no devices (or an empty slice) resets it to the default (nil), in which case the backend's
// default is used, which is usually simply the sequential devices starting from 0.
//
// The given devices are copied, so the caller may reuse its slice. The assignment only affects executors created
// after this call. It returns the Trainer itself, so calls can be cascaded.
func (r *Trainer) WithDeviceAssignment(deviceAssignment ...backends.DeviceNum) *Trainer {
	if len(deviceAssignment) == 0 {
		// An empty (non-nil) slice is normalized to nil, so executors fall back to the backend's default
		// assignment instead of being given an empty one.
		r.deviceAssignment = nil
		return r
	}
	r.deviceAssignment = slices.Clone(deviceAssignment)
	return r
}

// iterateExecs returns an iterator over every executor cached by the Trainer, going through the train,
// eval, batch-norm averages, accumulate-gradients and accumulate-and-apply-gradients executors, in this
// order of groups (the order within each group is unspecified).
//
// The maps are read when the iteration runs, not when iterateExecs is called.
func (r *Trainer) iterateExecs() iter.Seq[*context.Exec] {
	return func(yield func(*context.Exec) bool) {
		groups := []map[any]*context.Exec{
			r.trainStepExecMap,
			r.evalStepExecMap,
			r.batchNormStepExecMap,
			r.accumulateGradientsExecMap,
			r.accumulateGradientsAndApplyExecMap,
		}
		for _, group := range groups {
			for _, e := range group {
				if !yield(e) {
					return
				}
			}
		}
	}
}

// Context returns the Context currently used by the trainer. It can be replaced with SetContext.
func (r *Trainer) Context() *context.Context {
	current := r.context
	return current
}

// SetContext associates the given Context with the trainer, and with every executor it has already created.
// It should be called before any calls to Train or Evaluate.
//
// Notice that after the first time the context is used to build a graph, it is set to Reuse. If the Context
// variables were already created, it should be marked with Context.Reuse.
//
// It returns a reference to itself, so calls can be cascaded.
func (r *Trainer) SetContext(ctx *context.Context) *Trainer {
	r.context = ctx
	// Propagate the new context to the executors created so far.
	r.iterateExecs()(func(e *context.Exec) bool {
		e.SetContext(ctx)
		return true
	})
	return r
}

// WithMaxExecutors configures how many different executors the Trainer may create before failing.
//
// One executor is created for each different `spec` value yielded by the dataset, for each type of step
// used: training, evaluation (if Trainer.Eval is being used), etc.
// The default is DefaultMaxExecutors.
//
// It returns the Trainer itself, so calls can be cascaded.
func (r *Trainer) WithMaxExecutors(maxExecutors int) *Trainer {
	trainer := r
	trainer.maxExecutors = maxExecutors
	return trainer
}

// TrainMetrics returns the objects implementing the train metrics, not the metric values. The default
// loss metrics created by NewTrainer come first.
func (r *Trainer) TrainMetrics() []metrics.Interface {
	return r.trainMetrics
}

// EvalMetrics returns the objects implementing the eval metrics, not the metric values. Each object holds
// the metric name and a default pretty-printing function. The default loss metrics created by NewTrainer
// come first.
func (r *Trainer) EvalMetrics() []metrics.Interface {
	return r.evalMetrics
}

// createExecutor creates a new executor (train, eval, etc.) for the given spec, whose graph is built by graphFn.
//
// The executor takes as arguments the inputs followed by the labels: inputsLen and labelsLen tell how to split
// them, and they are recorded for the spec so later steps can check the dataset yields consistent lengths.
//
// It returns an error if it failed for any reason, including reaching maxExecutors: the limit applies to the
// total number of executors, summed over all types of steps.
func (r *Trainer) createExecutor(spec any, inputsLen, labelsLen int,
	graphFn func(spec any, ctx *context.Context, inputs, labels []*graph.Node) (metrics []*graph.Node)) (
	*context.Exec, error) {
	numExecs := len(r.trainStepExecMap) + len(r.evalStepExecMap) +
		len(r.accumulateGradientsExecMap) + len(r.accumulateGradientsAndApplyExecMap) +
		len(r.batchNormStepExecMap)
	if numExecs >= r.maxExecutors {
		return nil, errors.Errorf("Max number of executors (%d) reached: one is created for each "+
			"different value of `spec` returned by Dataset, triggering a different JIT-compiled "+
			"computation graph. Probably you want to limit the number of different datasets configuration "+
			"(spec) or shapes supported, or increase the allowed number of executors (see Trainer.WithMaxExecutors) "+
			"if this is what you want. Value of spec passed at this iteration: %+v", r.maxExecutors, spec)
	}
	if numExecs > 0 {
		r.context = r.context.Checked(false) // Only check for duplicate variables at the first graph creation.
	}
	trainerName := "Trainer"
	if _, found := spec.(fmt.Stringer); found {
		trainerName = fmt.Sprintf("Trainer: spec=%s", spec)
	}
	exec, err := context.NewExec(r.backend, r.context,
		func(ctx *context.Context, inputsAndLabels []*graph.Node) (metrics []*graph.Node) {
			inputs := inputsAndLabels[:inputsLen]
			labels := inputsAndLabels[inputsLen:]
			return graphFn(spec, ctx, inputs, labels)
		})
	if err != nil {
		// exec is not usable here: check the error before configuring it.
		return nil, errors.WithMessagef(err, "failed to create executor for spec %+v", spec)
	}
	if r.deviceAssignment != nil {
		exec = exec.WithDeviceAssignment(r.deviceAssignment)
	}
	// Only record the lengths once the executor was successfully created.
	r.inputsAndLabelsLenPerSpec[spec] = [2]int{inputsLen, labelsLen}
	return exec.WithName(trainerName), nil
}

// lossFnScalarLoss returns the loss computed by `r.lossFn`, reduced to a scalar with [ReduceAllMean]
// when it isn't one already.
//
// The caller must make sure `r.lossFn != nil`. The context argument is unused.
func (r *Trainer) lossFnScalarLoss(_ *context.Context, labels, predictions []*graph.Node) *graph.Node {
	result := r.lossFn(labels, predictions)
	if result.Shape().IsScalar() {
		return result
	}
	// Non-scalar losses are averaged over all their elements.
	return graph.ReduceAllMean(result)
}

// trainStepGraph builds the graph of one training step. The context executors in `r.trainStepExecMap`
// call it whenever a new graph is needed (typically for new batch sizes).
//
// It runs the model in training mode, registers the loss returned by r.lossFn (if set), lets the optimizer
// update the variables using the total loss, runs the registered per-step update functions, and returns the
// updates of the train metrics (r.trainMetrics), in order.
//
// It panics if there is no loss at all to optimize.
func (r *Trainer) trainStepGraph(spec any, ctx *context.Context, inputs, labels []*graph.Node) (metrics []*graph.Node) {
	g := inputs[0].Graph()
	// Flag the graph as training: some layers behave differently while training.
	ctx.SetTraining(g, true)

	predictions := r.modelFn(ctx, spec, inputs)
	if r.lossFn != nil {
		fnLoss := r.lossFnScalarLoss(ctx, labels, predictions)
		SetLossNoRegularization(ctx, fnLoss)
		AddLoss(ctx, fnLoss)
	}

	// The total loss combines the lossFn result with any other term registered with AddLoss.
	totalLoss := GetLosses(ctx, g)
	if totalLoss == nil {
		exceptions.Panicf(
			"no loss function defined (or it returned nil), and no loss set with AddLoss(), there is nothing to optimize!?",
		)
	}

	// The optimizer builds the gradient graph and the variable updates.
	r.optimizer.UpdateGraph(ctx, g, totalLoss)

	// Run the per-step update functions registered for this graph.
	ExecPerStepUpdateGraphFn(ctx, g)

	if len(predictions) == 0 {
		// Metrics need at least one prediction: use a scalar zero with the dtype of the loss.
		predictions = []*graph.Node{graph.ScalarZero(g, totalLoss.DType())}
	}
	return r.metricsUpdatesGraph(ctx, labels, predictions, r.trainMetrics)
}

// callGraphFn is the shared implementation of TrainStep and EvalStep.
//
// It validates the inputs and labels, then packs them as the executor arguments (inputs first, then labels).
// It fetches the executor cached in execMap for spec; if there is none yet, it creates one with graphFn and
// calls every registered onExecCreationHandlers with it and graphType. Finally, it executes it and returns the
// resulting metrics.
func (r *Trainer) callGraphFn(
	graphFn func(spec any, ctx *context.Context, inputs, labels []*graph.Node) (metrics []*graph.Node),
	graphType GraphType,
	execMap map[any]*context.Exec,
	spec any,
	inputs, labels []*tensors.Tensor,
) (metrics []*tensors.Tensor, err error) {
	if len(inputs) == 0 {
		return nil, errors.New("there are no inputs, at least one is required")
	}
	if want, seen := r.inputsAndLabelsLenPerSpec[spec]; seen && (len(inputs) != want[0] || len(labels) != want[1]) {
		return nil, errors.Errorf("dataset yields inputs (%d) and labels (%d) with lengths different "+
			"than with previous call (%d and %d) for the given spec %+v", len(inputs), len(labels),
			want[0], want[1], spec)
	}
	for idx := range inputs {
		if inputs[idx] == nil {
			return nil, errors.Errorf("inputs[%d] is nil!?", idx)
		}
	}

	// The executor takes the inputs followed by the labels, as a flat []any.
	args := make([]any, len(inputs)+len(labels))
	for idx, t := range inputs {
		args[idx] = t
	}
	for idx, t := range labels {
		args[len(inputs)+idx] = t
	}

	exec, cached := execMap[spec]
	if !cached {
		if exec, err = r.createExecutor(spec, len(inputs), len(labels), graphFn); err != nil {
			return nil, err
		}
		execMap[spec] = exec
		for _, onCreation := range r.onExecCreationHandlers {
			onCreation(exec, graphType)
		}
	}

	if metrics, err = exec.Exec(args...); err != nil {
		return nil, errors.WithMessage(err, "failed to execute train/eval step")
	}
	if len(metrics) == 0 {
		return nil, errors.New("no metrics calculate metric in step")
	}
	return metrics, nil
}

// distributedCallGraphFn is the distributed counterpart of callGraphFn, used by DistributedTrainStep and
// DistributedEvalStep.
//
// It validates the inputs and labels, and builds the executor arguments (inputs first, then labels). If there
// is no executor cached in execMap for spec, it creates one, configured for the given strategy and
// deviceAssignment, and calls every registered onExecCreationHandlers with it and graphType. Finally, it
// executes it.
//
// The strategy and deviceAssignment are only used when a new executor is created for the spec. An invalid
// strategy (None, SPMD with more than one mesh, or an unknown one) is reported before any executor is created.
//
// Notice that the returned metrics are not distributed: they are local copies of the first shard of each
// distributed metric, and the distributed metrics themselves are finalized.
func (r *Trainer) distributedCallGraphFn(
	strategy distributed.Strategy,
	deviceAssignment []backends.DeviceNum,
	graphFn func(spec any, ctx *context.Context, inputs, labels []*graph.Node) (metrics []*graph.Node),
	graphType GraphType,
	execMap map[any]*context.Exec,
	spec any,
	inputs, labels []*distributed.Tensor,
) (metrics []*tensors.Tensor, err error) {
	if len(inputs) == 0 {
		return nil, errors.New("there are no inputs, at least one is required")
	}
	if lengths, found := r.inputsAndLabelsLenPerSpec[spec]; found {
		if len(inputs) != lengths[0] || len(labels) != lengths[1] {
			return nil, errors.Errorf("dataset yields inputs (%d) and labels (%d) with lengths different "+
				"than with previous call (%d and %d) for the given spec %+v", len(inputs), len(labels),
				lengths[0], lengths[1], spec)
		}
	}
	for ii, input := range inputs {
		if input == nil {
			return nil, errors.Errorf("inputs[%d] is nil!?", ii)
		}
	}

	// Create arguments as []any for exec.DistributedExec(): inputs first, then labels.
	numParams := len(inputs) + len(labels)
	inputsAndLabels := make([]any, 0, numParams)
	for _, t := range inputs {
		inputsAndLabels = append(inputsAndLabels, t)
	}
	for _, t := range labels {
		inputsAndLabels = append(inputsAndLabels, t)
	}

	// Get the executor for the graphType and input spec.
	exec, found := execMap[spec]
	if !found {
		// Set up of the executor and compilation of the computation graph: this happens only
		// once for each spec.
		//
		// Find the meshes and aggregate the ShardingSpec of all inputs and labels (in argument order), and
		// validate the strategy, before creating the executor: this way an invalid configuration doesn't leave
		// behind an executor that is never registered nor finalized.
		var meshes []*distributed.DeviceMesh
		inputShardingSpecs := make([]*distributed.ShardingSpec, 0, numParams)
		for _, group := range [][]*distributed.Tensor{inputs, labels} {
			for _, t := range group {
				shardingSpec := t.ShardingSpec()
				if !slices.Contains(meshes, shardingSpec.Mesh) {
					meshes = append(meshes, shardingSpec.Mesh)
				}
				inputShardingSpecs = append(inputShardingSpecs, shardingSpec)
			}
		}
		if len(meshes) == 0 {
			return nil, errors.New("missing a mesh definition from the inputs/labels (likely from a DistributedDataset)")
		}
		switch strategy {
		case distributed.AutoSharding:
			// Any number of meshes is accepted.
		case distributed.SPMD:
			if len(meshes) > 1 {
				return nil, errors.Errorf("the distributed strategy SPMD only accepts one mesh for sharding the inputs, got %d",
					len(meshes))
			}
		case distributed.None:
			return nil, errors.New("cannot do a distributed train/eval step if the strategy is None -- try distributed.AutoSharding?")
		default:
			return nil, errors.Errorf("unsupported distributed strategy %v for a distributed train/eval step", strategy)
		}

		exec, err = r.createExecutor(spec, len(inputs), len(labels), graphFn)
		if err != nil {
			return nil, err
		}

		// Setup exec for distributed execution: the outputs (metrics) are replicated.
		replicatedSpec := distributed.NewReplicatedShardingSpec(meshes[0])
		if strategy == distributed.SPMD {
			exec = exec.SPMD(meshes[0]).
				WithInputShardingSpecs(inputShardingSpecs...).
				WithOutputShardingSpecs(replicatedSpec).
				WithDeviceAssignment(deviceAssignment)
		} else {
			exec = exec.AutoSharding(meshes...).
				WithInputShardingSpecs(inputShardingSpecs...).
				WithOutputShardingSpecs(replicatedSpec).
				WithDeviceAssignment(deviceAssignment)
		}

		// Register new executor.
		execMap[spec] = exec
		for _, handler := range r.onExecCreationHandlers {
			handler(exec, graphType)
		}
	}

	// Execute, it returns the distributed metrics:
	distributedMetrics, err := exec.DistributedExec(inputsAndLabels...)
	if err != nil {
		return nil, errors.WithMessage(err, "failed to execute train/eval step")
	}
	if len(distributedMetrics) == 0 {
		return nil, errors.New("no metrics calculate metric in step")
	}

	// Collect the metrics from the first device:
	metrics = make([]*tensors.Tensor, len(distributedMetrics))
	for i, distributedMetric := range distributedMetrics {
		metrics[i], err = distributedMetric.Shards()[0].LocalClone()
		if err != nil {
			return nil, errors.WithMessagef(err, "failed to clone metric %d from distributed metric", i)
		}
		err = distributedMetric.Finalize()
		if err != nil {
			return nil, errors.WithMessagef(err, "failed to finalize distributed metric %d", i)
		}
	}
	return metrics, nil
}

// ResetComputationGraphs finalizes and discards every cached executor, forcing the computation graphs to be
// rebuilt on the next steps. It can be called during training, in between steps.
//
// This is needed, for instance, when the training has schedules where hyperparameters change (e.g. some
// variables get frozen) and the computation graphs need to be updated accordingly.
//
// See Loop.OnStep to schedule it.
func (r *Trainer) ResetComputationGraphs() {
	allMaps := [...]map[any]*context.Exec{
		r.trainStepExecMap,
		r.evalStepExecMap,
		r.batchNormStepExecMap,
		r.accumulateGradientsExecMap,
		r.accumulateGradientsAndApplyExecMap,
	}
	for idx := range allMaps {
		for _, exec := range allMaps[idx] {
			exec.Finalize()
		}
		clear(allMaps[idx])
	}
}

// metricsUpdatesGraph builds the update graph of each of the given metrics, returning their results in the
// same order as metricsObjects. The metrics are given an unchecked version of ctx.
func (r *Trainer) metricsUpdatesGraph(ctx *context.Context, labels, predictions []*graph.Node,
	metricsObjects []metrics.Interface) (metrics []*graph.Node) {
	unchecked := ctx.Checked(false)
	metrics = make([]*graph.Node, len(metricsObjects))
	for idx, m := range metricsObjects {
		metrics[idx] = m.UpdateGraph(unchecked, labels, predictions)
	}
	return metrics
}

// TrainStep runs one training step and returns the metrics.
//
// The arguments usually come from `Dataset.Yield`, which has a more detailed description. In short:
//
//   - spec: provided by the dataset, often just nil. Each distinct value triggers the creation of different
//     computation graphs. These are normally static values (for the dataset) describing the inputs. See the
//     longer discussion in `train.Dataset`.
//   - inputs: always a slice, even though commonly it holds only one tensor. At least one input is required.
//     For each `spec` value, the number of inputs and labels must remain constant, otherwise an error is returned.
//   - labels: also always a slice, even if commonly with only one tensor.
//
// The returned metrics start with the batch loss and the moving exponential average of the batch loss,
// followed by the other `trainMetrics` configured when the Trainer was created.
//
// If gradient accumulation is enabled, the step is delegated to the gradient accumulation implementation.
func (r *Trainer) TrainStep(spec any, inputs, labels []*tensors.Tensor) (metrics []*tensors.Tensor, err error) {
	if !r.accumulateGradients {
		return r.callGraphFn(r.trainStepGraph, TrainType, r.trainStepExecMap, spec, inputs, labels)
	}
	return r.trainStepWithAccumulateGradients(spec, inputs, labels)
}

// DistributedTrainStep runs one training step in a distributed fashion and returns the metrics (not distributed).
//
// The strategy and device assignment are only used when a new executor is built, that is, only when a new
// dataset spec is seen.
//
// Gradient accumulation is not supported yet in distributed mode: an error is returned if it is enabled.
// Otherwise, it behaves just like TrainStep.
func (r *Trainer) DistributedTrainStep(strategy distributed.Strategy, deviceAssignment []backends.DeviceNum,
	spec any, inputs, labels []*distributed.Tensor) (metrics []*tensors.Tensor, err error) {
	if !r.accumulateGradients {
		return r.distributedCallGraphFn(strategy, deviceAssignment, r.trainStepGraph, TrainType,
			r.trainStepExecMap, spec, inputs, labels)
	}
	return nil, errors.New("distributed training with gradient accumulation not implemented yet")
}

// evalStepGraph builds the graph to evaluate one step. It is called by the context executor (`r.evalStepExecMap`)
// every time a graph needs to be built (typically for new batch sizes).
//
// It runs the model with training disabled, registers the loss returned by r.lossFn (if set), and returns the
// updates of the eval metrics (r.evalMetrics), in order.
func (r *Trainer) evalStepGraph(spec any, ctx *context.Context, inputs, labels []*graph.Node) (metrics []*graph.Node) {
	g := inputs[0].Graph()
	ctx.SetTraining(g, false) // Some layers behave differently in train/eval.

	predictions := r.modelFn(ctx, spec, inputs)
	if r.lossFn != nil {
		baseLoss := r.lossFnScalarLoss(ctx, labels, predictions)
		SetLossNoRegularization(ctx, baseLoss)
		AddLoss(ctx, baseLoss)
	}

	if len(predictions) == 0 {
		// Like in trainStepGraph: the metrics (including the default loss metrics, which read predictions[0])
		// require at least one prediction, so we create a scalar zero. Its dtype is the loss dtype if there is
		// a loss, and the dtype of the first input otherwise.
		dtype := inputs[0].DType()
		if loss := GetLosses(ctx, g); loss != nil {
			dtype = loss.DType()
		}
		predictions = []*graph.Node{graph.ScalarZero(g, dtype)}
	}

	// Get metrics and updates
	metrics = r.metricsUpdatesGraph(ctx, labels, predictions, r.evalMetrics)
	return
}

// ResetTrainMetrics calls Metrics.Reset on all train metrics. Usually called before a training session.
//
// It returns an error identifying the metric (instead of panicking) if resetting any of them fails; the
// remaining metrics are not reset in that case.
func (r *Trainer) ResetTrainMetrics() error {
	ctx := r.context.Checked(false)
	for _, metric := range r.trainMetrics {
		err := exceptions.TryCatch[error](func() { metric.Reset(ctx) })
		if err != nil {
			return errors.WithMessagef(err, "ResetTrainMetrics() failed to reset metric %q", metric.Name())
		}
	}
	return nil
}

// EvalStep runs one evaluation step and returns the current values of the registered eval metrics, the
// first one being the mean loss.
//
// Its parameters are the output of a Dataset.Yield call, exactly as with TrainStep.
func (r *Trainer) EvalStep(spec any, inputs, labels []*tensors.Tensor) (metrics []*tensors.Tensor, err error) {
	metrics, err = r.callGraphFn(r.evalStepGraph, EvalType, r.evalStepExecMap, spec, inputs, labels)
	return metrics, err
}

// DistributedEvalStep runs one evaluation step in a distributed fashion and returns the metrics (not distributed).
//
// The strategy and device assignment are only used when a new executor is built, that is, only when a new
// dataset spec is seen.
//
// Otherwise, it behaves just like EvalStep.
func (r *Trainer) DistributedEvalStep(strategy distributed.Strategy, deviceAssignment []backends.DeviceNum,
	spec any, inputs, labels []*distributed.Tensor) (metrics []*tensors.Tensor, err error) {
	metrics, err = r.distributedCallGraphFn(strategy, deviceAssignment, r.evalStepGraph, EvalType,
		r.evalStepExecMap, spec, inputs, labels)
	return metrics, err
}

// resetEvalMetrics calls Metrics.Reset on every eval metric, stopping at (and returning) the first failure.
func (r *Trainer) resetEvalMetrics() error {
	for _, m := range r.evalMetrics {
		if err := exceptions.TryCatch[error](func() { m.Reset(r.context) }); err != nil {
			return errors.WithMessagef(err, "Eval() failed to reset metric %q", m.Name())
		}
	}
	return nil
}

// Eval returns the computation of loss and metrics over the given dataset. The dataset
// has to be finite (yield io.EOF at the end). The function will reset the dataset and the
// eval metrics at the start.
//
// If the dataset is a DistributedDataset, it will be evaluated in a distributed fashion, see DistributedEval.
//
// Metrics implementing metrics.UpdateGo are also fed, in Go, the values of each step, and their final values
// are read with ReadGo at the end.
//
// It returns an error if the dataset yields no batches.
//
// Note: if finalizeYieldedTensors reports so for the dataset, inputs and labels yielded by the dataset are
// immediately finalized (freed) after use.
func (r *Trainer) Eval(ds Dataset) (lossAndMetrics []*tensors.Tensor, err error) {
	if distributedDS, ok := ds.(DistributedDataset); ok {
		return r.DistributedEval(distributedDS)
	}
	ds.Reset()
	err = r.resetEvalMetrics()
	if err != nil {
		return nil, err
	}
	count := 0
	finalizeInputs := finalizeYieldedTensors(ds)

	// Check for metrics with Go updates: these are update functions not written as a computation graph.
	goUpdateFns := make([]metrics.UpdateGo, len(r.evalMetrics))
	for ii, metric := range r.evalMetrics {
		if fn, ok := metric.(metrics.UpdateGo); ok {
			goUpdateFns[ii] = fn
		}
	}

	// Loop over dataset:
	for {
		spec, inputs, labels, err := ds.Yield()
		if err == io.EOF {
			break
		}
		if err != nil {
			return nil, errors.Wrap(err, "dataset returned an error during Eval")
		}
		count++
		// Early free (not wait for the GC) of the results of previous batch.
		for _, t := range lossAndMetrics {
			err := t.FinalizeAll()
			if err != nil {
				return nil, errors.WithMessagef(err,
					"finalizing loss and metrics tensor of dataset %q after use in an eval step",
					ds.Name())
			}
		}

		lossAndMetrics, err = r.EvalStep(spec, inputs, labels)
		if err != nil {
			return nil, errors.WithMessage(err, "EvalStep failed")
		}
		for i, goUpdateFn := range goUpdateFns {
			if goUpdateFn != nil {
				goUpdateFn.UpdateGo(lossAndMetrics[i])
			}
		}

		// Free inputs and labels after usage.
		if finalizeInputs {
			for sliceIdx, slice := range [][]*tensors.Tensor{inputs, labels} {
				for i, t := range slice {
					err := t.FinalizeAll()
					if err != nil {
						return nil, errors.WithMessagef(
							err, "finalizing %s tensor #%d of dataset %q after use in an eval step",
							yieldInputTypeNames[sliceIdx], i, ds.Name())
					}
				}
			}
		}
	}
	if count == 0 {
		return nil, errors.New("evaluation dataset yielded no batches, no data to evaluate")
	}

	// Read out the Go-updated metrics:
	for i, goUpdateFn := range goUpdateFns {
		if goUpdateFn != nil {
			lossAndMetrics[i] = goUpdateFn.ReadGo()
		}
	}
	return lossAndMetrics, nil
}

// DistributedEval returns the computation of loss and metrics over the given distributed dataset.
// The dataset has to be finite (yield io.EOF at the end). The function will reset the dataset at the start.
//
// It reads the strategy and device assignment from the DistributedDataset and uses DistributedEvalStep
// for each batch.
//
// Metrics implementing metrics.UpdateGo are updated on the Go side after every step, and the values
// returned by their ReadGo replace the corresponding entries of the returned slice.
//
// Memory: the loss/metrics tensors of each batch are finalized (freed) as soon as the next batch is
// evaluated. This includes the last batch's outputs that are replaced by ReadGo values. The inputs and
// labels yielded by the dataset are finalized right after use only when finalizeYieldedTensors(ds)
// returns true.
//
// It returns an error if the dataset yields no batches, if the dataset or an eval step fails, or if
// finalizing any of the intermediary tensors fails.
func (r *Trainer) DistributedEval(ds DistributedDataset) (lossAndMetrics []*tensors.Tensor, err error) {
	ds.Reset()
	if err = r.resetEvalMetrics(); err != nil {
		return nil, err
	}
	finalizeInputs := finalizeYieldedTensors(ds)
	strategy := ds.Strategy()
	deviceAssignment := ds.DeviceAssignment()

	// Check for metrics with Go updates: these are update functions not written as a computation graph.
	goUpdateFns := make([]metrics.UpdateGo, len(r.evalMetrics))
	for ii, metric := range r.evalMetrics {
		if fn, ok := metric.(metrics.UpdateGo); ok {
			goUpdateFns[ii] = fn
		}
	}

	// Loop over dataset:
	count := 0
	for {
		spec, inputs, labels, yieldErr := ds.DistributedYield()
		if yieldErr == io.EOF {
			break
		}
		if yieldErr != nil {
			return nil, errors.Wrap(yieldErr, "dataset returned an error during DistributedEval")
		}
		count++

		// Early free (not wait for the GC) of the results of previous batch.
		for _, t := range lossAndMetrics {
			if err = t.FinalizeAll(); err != nil {
				return nil, errors.WithMessagef(err,
					"finalizing loss and metrics tensor of dataset %q after use in a distributed eval step",
					ds.Name())
			}
		}

		lossAndMetrics, err = r.DistributedEvalStep(strategy, deviceAssignment, spec, inputs, labels)
		if err != nil {
			return nil, errors.WithMessage(err, "DistributedEvalStep failed")
		}
		for i, goUpdateFn := range goUpdateFns {
			if goUpdateFn != nil {
				goUpdateFn.UpdateGo(lossAndMetrics[i])
			}
		}

		// Free inputs and labels after usage.
		if finalizeInputs {
			for sliceIdx, slice := range [][]*distributed.Tensor{inputs, labels} {
				for i, t := range slice {
					if err = t.FinalizeAll(); err != nil {
						return nil, errors.WithMessagef(
							err, "finalizing %s distributed tensor #%d of dataset %q after use in a distributed eval step",
							yieldInputTypeNames[sliceIdx], i, ds.Name())
					}
				}
			}
		}
	}
	if count == 0 {
		return nil, errors.New("evaluation dataset yielded no batches, no data to evaluate")
	}

	// Read out the go-generated metrics, freeing the last batch's graph output they replace.
	// This mirrors the early freeing done inside the loop.
	for i, goUpdateFn := range goUpdateFns {
		if goUpdateFn == nil {
			continue
		}
		replaced := lossAndMetrics[i]
		lossAndMetrics[i] = goUpdateFn.ReadGo()
		// Don't free the tensor if ReadGo handed back the very same tensor.
		if replaced != nil && replaced != lossAndMetrics[i] {
			if err = replaced.FinalizeAll(); err != nil {
				return nil, errors.WithMessagef(err,
					"finalizing metric tensor #%d of dataset %q replaced by its Go-updated value in DistributedEval",
					i, ds.Name())
			}
		}
	}
	return lossAndMetrics, nil
}

// Metrics returns the eval metrics registered with the trainer. The list includes the loss metric,
// which is added automatically.
//
// The returned slice is the trainer's own slice, not a copy, so callers should not modify it.
func (r *Trainer) Metrics() []metrics.Interface {
	registered := r.evalMetrics
	return registered
}

// GlobalStep returns the current global step stored in the trainer's context.
//
// It is a shortcut for calling optimizers.GetGlobalStep with the trainer's context (see Trainer.Context).
func (r *Trainer) GlobalStep() int64 {
	trainerCtx := r.context
	return optimizers.GetGlobalStep(trainerCtx)
}

// OnExecFn is a handler called when the trainer creates an executor.
// It receives the newly created executor and the GraphType it was created for.
// See Trainer.OnExecCreation.
type OnExecFn func(exec *context.Exec, graphType GraphType)

// OnExecCreation registers a handler to be called each time the trainer creates an executor (`context.Exec`).
//
// Different executors are created for training, evaluation, BatchNormalization and gradient accumulation
// (the GraphType passed to the handler reflects that). Separate executors are also created for each
// distinct `spec` value received from the Dataset.
//
// The `handler` is also given the type (e.g. TrainGraph or EvalGraph) the executor is created for.
// Handlers are appended, so registering multiple handlers calls all of them.
func (r *Trainer) OnExecCreation(handler OnExecFn) {
	r.onExecCreationHandlers = append(r.onExecCreationHandlers, handler)
}

// AddLoss accumulates the given loss term into the trainer's graph parameters (stored under
// TrainerLossGraphParamKey in the TrainerAbsoluteScope context scope). This is the loss the trainer
// optimizes.
//
// It can be called multiple times for the same graph: each new term is added to the running sum.
//
// A non-scalar `loss` (often the batch axis is not reduced) is first reduced to a scalar with
// `graph.ReduceAllMean`.
//
// If every loss term is provided through AddLoss, nil can be passed as the `lossFn` parameter to the Trainer.
func AddLoss(ctx *context.Context, loss *graph.Node) {
	g := loss.Graph()
	if !loss.Shape().IsScalar() {
		loss = graph.ReduceAllMean(loss)
	}
	trainerCtx := ctx.InAbsPath(TrainerAbsoluteScope)
	if accumulated, found := trainerCtx.GetGraphParam(g, TrainerLossGraphParamKey); found {
		loss = graph.Add(loss, accumulated.(*graph.Node))
	}
	trainerCtx.SetGraphParam(g, TrainerLossGraphParamKey, loss)
}

// GetLosses returns the sum of all loss terms added with AddLoss() for graph g, or nil if none was set.
//
// Usually this is used by the trainer after all losses are accounted for. But it can be used
// by arbitrary modeling functions, in particular after the optimizer update (see AddPerStepUpdateGraphFn).
func GetLosses(ctx *context.Context, g *graph.Graph) (loss *graph.Node) {
	ctxTrainer := ctx.InAbsPath(TrainerAbsoluteScope)
	lossAny, found := ctxTrainer.GetGraphParam(g, TrainerLossGraphParamKey)
	if !found || lossAny == nil {
		// A single-value type assertion on a nil interface panics, so return nil explicitly as documented.
		return nil
	}
	return lossAny.(*graph.Node)
}

// SetLossNoRegularization stores, for the graph of `loss`, a reference to the model loss computed by the
// lossFn used in training. This loss excludes any regularization terms.
//
// Metrics that need the loss without regularization can retrieve it with GetLossNoRegularization.
func SetLossNoRegularization(ctx *context.Context, loss *graph.Node) {
	trainerCtx := ctx.InAbsPath(TrainerAbsoluteScope)
	g := loss.Graph()
	trainerCtx.SetGraphParam(g, TrainerLossNoRegularizationGraphParamKey, loss)
}

// GetLossNoRegularization returns, for graph g, the model loss excluding regularization terms (which are
// usually added with AddLoss). It returns nil unless the loss was stored with SetLossNoRegularization.
func GetLossNoRegularization(ctx *context.Context, g *graph.Graph) (loss *graph.Node) {
	stored, found := ctx.InAbsPath(TrainerAbsoluteScope).GetGraphParam(g, TrainerLossNoRegularizationGraphParamKey)
	if found {
		loss = stored.(*graph.Node)
	}
	return loss
}

// ContextGraphFn is a generic graph building function, given a context and the graph being built.
type ContextGraphFn func(ctx *context.Context, g *graph.Graph)

// AddPerStepUpdateGraphFn registers the given function fn to be executed at every training step, after the
// optimizer updates the variables with the gradients. fn is called with the context set to the same scope
// it was registered with.
//
// This allows one, for instance, to implement variable constraints.
//
// Only one ContextGraphFn can be registered per scope per graph: registering another one in the same scope
// replaces the previous one. To register more than one such function, use different context scopes.
//
// Note that this is executed after the optimizer, and therefore also after the loss is calculated.
// Changes made to the model weights won't be reflected in the loss returned by the training step.
// Most metrics won't reflect them either: the metrics are updated after this hook, but they typically
// use predictions that were generated earlier in the training step.
//
// If you are writing a custom "TrainStep" function, you need to call ExecPerStepUpdateGraphFn after
// Optimizer.Update (or your custom updates). The Trainer does that for you already.
func AddPerStepUpdateGraphFn(ctx *context.Context, g *graph.Graph, fn ContextGraphFn) {
	ctx.SetGraphParam(g, TrainerPerStepUpdateGraphFnParamKey, fn)
}

// ExecPerStepUpdateGraphFn executes all "per-step update functions" registered with
// AddPerStepUpdateGraphFn for graph g. Each one is called with the context scope it was registered with.
//
// This should be called by the "TrainStep" function of a trainer, just after calling the
// Optimizer.Update method.
// If you are using the Trainer, it already does that for you. But if you are writing your own train step
// function, you may want to call this.
//
// The functions run in the order EnumerateGraphParams reports them, but only after the enumeration
// has finished. Registered functions are free to read or set graph parameters themselves without
// interfering with the enumeration.
func ExecPerStepUpdateGraphFn(ctx *context.Context, g *graph.Graph) {
	type scopedFn struct {
		scope string
		fn    ContextGraphFn
	}
	var pending []scopedFn
	ctx.EnumerateGraphParams(g, func(scope string, key string, value any) {
		if key != TrainerPerStepUpdateGraphFnParamKey {
			return
		}
		if fn, ok := value.(ContextGraphFn); ok && fn != nil {
			pending = append(pending, scopedFn{scope: scope, fn: fn})
		}
	})
	for _, p := range pending {
		p.fn(ctx.InAbsPath(p.scope), g)
	}
}
